import numpy as np
import torch
# Status codes compared against the results of the get_status() calls in
# Graph.update_feature. Each value is also used there as the index of the 1 in
# a one-hot status vector, so the numeric values are significant.
AVAILABLE = 0
PROCESSED = 1
# COMPLETE has the largest value: when completed operation nodes are deleted
# from the graph, the remaining codes (0-2) fit a 3-slot one-hot vector.
COMPLETE = 3
FUTURE = 2

def bsearch(a, left, right, x):
    """Binary-search ``a[left..right]`` (both ends inclusive) for ``x``.

    ``a`` must be sorted in ascending order over the searched range.

    Args:
        a: Indexable sequence of mutually comparable items.
        left: Index of the first candidate position.
        right: Index of the last candidate position (inclusive).
        x: Value to look for.

    Returns:
        An index ``i`` with ``a[i] == x``, or ``-1`` when ``x`` is not present
        in the range (including when the range is empty, i.e. left > right).
    """
    lo, hi = left, right
    while True:
        if lo > hi:
            # The window collapsed without a hit.
            return -1
        pivot = (lo + hi) // 2
        probe = a[pivot]
        if probe == x:
            return pivot
        if probe < x:
            # Target can only be to the right of the pivot.
            lo = pivot + 1
        else:
            # Target can only be to the left of the pivot.
            hi = pivot - 1

class Graph:
    """Dense operation/machine graph with per-node feature matrices.

    Attributes:
        op_op_edge_idx: Square float32 matrix over operation nodes; ``add_job``
            sets the block belonging to one job's operations to 1.
        op_m_edge_idx: float32 matrix (operation x machine); ``add_job`` stores
            the second element of each ``machine_and_processtime`` pair at
            ``[operation, machine]`` (presumably the processing time).
        m_m_edge_idx: Identity matrix over machines.
        op_x, m_x: Feature matrices rebuilt by ``update_feature``; empty lists
            until the first call.
        op_num: Number of operation nodes registered so far via ``add_job``.
        op_unfinished: Ascending list of node ids that still have a row in the
            edge matrices; position ``i`` in the list corresponds to row ``i``.
            Only consulted when ``args.delete_node`` is enabled. It is not
            populated by any method shown in this class - presumably the
            caller fills it; verify against the caller.
    """

    def __init__(self, args, job_num, machine_num, op_num=0):
        """Allocate empty edge matrices.

        Args:
            args: Configuration object; ``args.delete_node`` is read by
                ``update_feature``.
            job_num: Number of jobs.
            machine_num: Number of machines.
            op_num: Capacity in operation nodes. ``0`` (train mode) sizes the
                matrices for ``job_num * machine_num`` operations instead.
        """
        # op_num == 0 is the train-mode default capacity.
        node_capacity = job_num * machine_num if op_num == 0 else op_num
        self.op_op_edge_idx = torch.zeros(size=(node_capacity, node_capacity), dtype=torch.float32)
        self.op_m_edge_idx = torch.zeros(size=(node_capacity, machine_num), dtype=torch.float32)
        self.op_x = []
        self.m_x = []

        self.args = args
        self.job_num = job_num
        self.machine_num = machine_num
        # Counts nodes actually added; unrelated to the capacity argument above.
        self.op_num = 0
        # Must be a list: update_feature binary-searches it by position and
        # removes entries from it, neither of which works on a dict.
        self.op_unfinished = []

        self.m_m_edge_idx = torch.eye(machine_num, dtype=torch.float32)

    def get_data(self):
        """Return the edge matrices and current feature matrices as a dict."""
        return {
            "op_op_idx": self.op_op_edge_idx,
            "op_m_idx": self.op_m_edge_idx,
            "m_m_idx": self.m_m_edge_idx,
            "op_x": self.op_x,
            "m_x": self.m_x,
        }

    def add_job(self, job):
        """Register every operation of ``job`` as a new node.

        Assigns consecutive ``node_id`` values starting at the current
        ``op_num``, connects all of the job's operations to each other in
        ``op_op_edge_idx`` and records each operation's machine entries in
        ``op_m_edge_idx``.

        Args:
            job: Object exposing ``op_num``, ``operations`` and, per operation,
                ``machine_and_processtime`` (iterable of indexable pairs).
        """
        first = self.op_num
        last = first + job.op_num
        self.op_op_edge_idx[first:last, first:last] = 1
        for i in range(job.op_num):
            op = job.operations[i]
            op.node_id = self.op_num
            for mach_and_ptime in op.machine_and_processtime:
                machine_id, process_time = mach_and_ptime[0], mach_and_ptime[1]
                self.op_m_edge_idx[self.op_num, machine_id] = process_time
            self.op_num += 1

    def update_feature(self, jobs, machines, current_time, max_process_time):
        """Rebuild ``op_x`` and ``m_x`` for the given simulation time.

        Operation row layout: one-hot status (4 slots, or 3 when
        ``args.delete_node`` is enabled), a time feature scaled by
        ``max_process_time``, and
        ``acc_expected_process_time[op_id] / acc_expected_process_time[0]``.
        Machine row layout: 2-slot one-hot status, then the time until the
        machine is available scaled by ``max_process_time`` (0 if available).

        When ``args.delete_node`` is enabled, operations missing from
        ``op_unfinished`` are skipped, and operations that report COMPLETE are
        removed from both ``op_unfinished`` and the edge matrices.

        Args:
            jobs: Iterable of jobs exposing ``operations`` and
                ``acc_expected_process_time``.
            machines: Iterable of machines exposing ``get_status`` and
                ``avai_time``.
            current_time: Current simulation time.
            max_process_time: Normalisation constant for time features.
        """
        # Kept as an equality test (not plain truthiness) so that non-bool
        # flag values behave exactly as before; it is loop-invariant.
        delete_node = self.args.delete_node == True  # noqa: E712

        op_rows = []
        for job in jobs:
            for op in job.operations:
                if delete_node:
                    idx = bsearch(self.op_unfinished, 0, len(self.op_unfinished) - 1, op.node_id)
                    if idx == -1:
                        # Node is no longer part of the graph.
                        continue
                    status = op.get_status(current_time)
                    if status == COMPLETE:
                        # Drop the matrix row/column and the matching list
                        # entry together so positions stay aligned.
                        self.update_graph(idx)
                        del self.op_unfinished[idx]
                        continue
                    # COMPLETE (3) cannot occur past this point, so 3 slots suffice.
                    feat = [0] * 3
                else:
                    status = op.get_status(current_time)
                    feat = [0] * 4
                feat[status] = 1
                if status == AVAILABLE or status == FUTURE:
                    # NOTE(review): fixed value of 5.0 time units for operations
                    # that are not running; the origin of the constant is not
                    # visible here.
                    feat.append(5.0 / max_process_time)
                elif status == PROCESSED:
                    feat.append((op.finish_time - current_time) / max_process_time)
                else:
                    feat.append(0)
                feat.append(job.acc_expected_process_time[op.op_id] / job.acc_expected_process_time[0])
                op_rows.append(feat)

        m_rows = []
        for m in machines:
            status = m.get_status(current_time)
            # Two slots only: machine status is used directly as index 0 or 1.
            feat = [0] * 2
            feat[status] = 1
            if status == AVAILABLE:
                feat.append(0)
            else:
                feat.append((m.avai_time() - current_time) / max_process_time)
            m_rows.append(feat)

        self.op_x = torch.Tensor(op_rows)
        self.m_x = torch.Tensor(m_rows)

    def update_graph(self, idx):
        """Remove the operation node at matrix position ``idx``.

        Deletes row and column ``idx`` from ``op_op_edge_idx`` and row ``idx``
        from ``op_m_edge_idx``. The matrices are replaced by new tensors.

        Args:
            idx: Current row position of the node (not its ``node_id``).
        """
        op_op = self.op_op_edge_idx
        op_op = torch.cat((op_op[0:idx, :], op_op[idx + 1:, :]), dim=0)
        op_op = torch.cat((op_op[:, 0:idx], op_op[:, idx + 1:]), dim=1)
        self.op_op_edge_idx = op_op

        op_m = self.op_m_edge_idx
        self.op_m_edge_idx = torch.cat((op_m[0:idx, :], op_m[idx + 1:, :]), dim=0)